{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "networks_seq2seq_nmt.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "source": [],
        "metadata": {
          "collapsed": false
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "f9ySOjrcc0Yp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bl9GdT7h0Hxk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WhwgQAn50EZp",
        "colab_type": "text"
      },
      "source": [
        "# TensorFlow Addons Networks : Sequence-to-Sequence NMT with Attention Mechanism\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "      <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/networks_seq2seq_nmt.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ip0n8178Fuwm",
        "colab_type": "text"
      },
      "source": [
        "## Overview\n",
        "This notebook gives a brief introduction into the ***Sequence to Sequence Model Architecture***\n",
        "In this noteboook we broadly cover four essential topics necessary for Neural Machine Translation:\n",
        "\n",
        "\n",
        "* **Data cleaning**\n",
        "* **Data preparation**\n",
        "* **Neural Translation Model with Attention**\n",
        "* **Final Translation**\n",
        "\n",
        "The basic idea behind such a model though, is only the encoder-decoder architecture. These networks are usually used for a variety of tasks like text-summerization, Machine translation, Image Captioning, etc. This tutorial provideas a hands-on understanding of the concept, explaining the technical jargons wherever necessary. We focus on the task of Neural Machine Translation (NMT) which was the very first testbed for seq2seq models.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YNiadLKNLleD",
        "colab_type": "text"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "82GcQTsGf414",
        "colab_type": "text"
      },
      "source": [
        "## Additional Resources:\n",
        "\n",
        "These are a lst of resurces you must install in order to allow you to run this notebook:\n",
        "\n",
        "\n",
        "1. [German-English Dataset](http://www.manythings.org/anki/deu-eng.zip)\n",
        "\n",
        "\n",
        "The dataset should be downloaded, in order to compile this notebook, the embeddings can be used, as they are pretrained. Though, we carry out our own training here !!\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5OIlpST_6ga-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#download data\n",
        "print(\"Downloading Dataset:\")\n",
        "!wget --quiet http://www.manythings.org/anki/deu-eng.zip\n",
        "!unzip deu-eng.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "co6-YpBwL-4d",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "6571961c-8f50-4333-9b1d-5eb1a157f4f8"
      },
      "source": [
        "import csv\n",
        "import string\n",
        "import re\n",
        "from typing import List, Tuple\n",
        "from pickle import dump\n",
        "from unicodedata import normalize\n",
        "import numpy as np\n",
        "import itertools\n",
        "from pickle import load\n",
        "from tensorflow.keras.utils import to_categorical\n",
        "from keras.utils.vis_utils import plot_model\n",
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.layers import LSTM\n",
        "from tensorflow.keras.layers import Dense\n",
        "from tensorflow.keras.layers import Embedding\n",
        "from pickle import load\n",
        "import random\n",
        "import tensorflow as tf\n",
        "from keras.models import load_model\n",
        "from nltk.translate.bleu_score import corpus_bleu\n",
        "from sklearn.model_selection import train_test_split\n",
        "import tensorflow_addons as tfa"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using TensorFlow backend.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q7gjUT_9XSoj",
        "colab_type": "text"
      },
      "source": [
        "## Data Cleaning\n",
        "\n",
        "Our data set is a German-English translation dataset. It contains 152,820 pairs of English to German phases, one pair per line with a tab separating the language. These dataset though organized needs cleaning before we can work on it. This will enable us to remove unnecessary bumps that may come in during the training. We also added start-of-sentence `<start>` and end-of-sentence `<end>` so that the model knows when to start and stop predicting."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6ZIu-TNqKFsd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start of sentence\n",
        "SOS = \"<start>\"\n",
        "# End of sentence\n",
        "EOS = \"<end>\"\n",
        "# Relevant punctuation\n",
        "PUNCTUATION = set(\"?,!.\")\n",
        "\n",
        "\n",
        "def load_dataset(filename: str) -> str:\n",
        "    \"\"\"\n",
        "    load dataset into memory\n",
        "    \"\"\"\n",
        "    with open(filename, mode=\"rt\", encoding=\"utf-8\") as fp:\n",
        "        return fp.read()\n",
        "\n",
        "\n",
        "def to_pairs(dataset: str, limit: int = None, shuffle=False) -> List[Tuple[str, str]]:\n",
        "    \"\"\"\n",
        "    Split dataset into pairs of sentences, discards dataset line info.\n",
        "\n",
        "    e.g.\n",
        "    input -> 'Go.\\tGeh.\\tCC-BY 2.0 (France) Attribution: tatoeba.org\n",
        "    #2877272 (CM) & #8597805 (Roujin)'\n",
        "    output -> [('Go.', 'Geh.')]\n",
        "\n",
        "    :param dataset: dataset containing examples of translations between\n",
        "    two languages\n",
        "    the examples are delimited by `\\n` and the contents of the lines are\n",
        "    delimited by `\\t`\n",
        "    :param limit: number that limit dataset size (optional)\n",
        "    :param shuffle: default is True\n",
        "    :return: list of pairs\n",
        "    \"\"\"\n",
        "    assert isinstance(limit, (int, type(None))), TypeError(\n",
        "        \"the limit value must be an integer\"\n",
        "    )\n",
        "    lines = dataset.strip().split(\"\\n\")\n",
        "    # Radom dataset\n",
        "    if shuffle is True:\n",
        "        random.shuffle(lines)\n",
        "    number_examples = limit or len(lines)  # if None get all\n",
        "    pairs = []\n",
        "    for line in lines[: abs(number_examples)]:\n",
        "        # take only source and target\n",
        "        src, trg, _ = line.split(\"\\t\")\n",
        "        pairs.append((src, trg))\n",
        "\n",
        "    # dataset size check\n",
        "    assert len(pairs) == number_examples\n",
        "    return pairs\n",
        "\n",
        "\n",
        "def separe_punctuation(token: str) -> str:\n",
        "    \"\"\"\n",
        "    Separe punctuation if exists\n",
        "    \"\"\"\n",
        "\n",
        "    if not set(token).intersection(PUNCTUATION):\n",
        "        return token\n",
        "    for p in PUNCTUATION:\n",
        "        token = f\" {p} \".join(token.split(p))\n",
        "    return \" \".join(token.split())\n",
        "\n",
        "\n",
        "def preprocess(sentence: str, add_start_end: bool=True) -> str:\n",
        "    \"\"\"\n",
        "    - convert lowercase\n",
        "    - remove numbers\n",
        "    - remove special characters\n",
        "    - separe punctuation\n",
        "    - add start-of-sentence <start> and end-of-sentence <end>\n",
        "\n",
        "    :param add_start_end: add SOS (start-of-sentence) and EOS (end-of-sentence)\n",
        "    \"\"\"\n",
        "    re_print = re.compile(f\"[^{re.escape(string.printable)}]\")\n",
        "    # convert lowercase and normalizing unicode characters\n",
        "    sentence = (\n",
        "        normalize(\"NFD\", sentence.lower()).encode(\"ascii\", \"ignore\").decode(\"UTF-8\")\n",
        "    )\n",
        "    cleaned_tokens = []\n",
        "    # tokenize sentence on white space\n",
        "    for token in sentence.split():\n",
        "        # removing non-printable chars form each token\n",
        "        token = re_print.sub(\"\", token).strip()\n",
        "        # ignore tokens with numbers\n",
        "        if re.findall(\"[0-9]\", token):\n",
        "            continue\n",
        "        # add space between words and punctuation eg: \"ok?go!\" => \"ok ? go !\"\n",
        "        token = separe_punctuation(token)\n",
        "        cleaned_tokens.append(token)\n",
        "\n",
        "    # rebuild sentence with space between tokens\n",
        "    sentence = \" \".join(cleaned_tokens)\n",
        "\n",
        "    # adding a start and an end token to the sentence\n",
        "    if add_start_end is True:\n",
        "        sentence = f\"{SOS} {sentence} {EOS}\"\n",
        "    return sentence\n",
        "\n",
        "\n",
        "def dataset_preprocess(dataset: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:\n",
        "    \"\"\"\n",
        "    Returns processed database\n",
        "\n",
        "    :param dataset: list of sentence pairs\n",
        "    :return: list of paralel data e.g. \n",
        "    (['first source sentence', 'second', ...], ['first target sentence', 'second', ...])\n",
        "    \"\"\"\n",
        "    source_cleaned = []\n",
        "    target_cleaned = []\n",
        "    for source, target in dataset:\n",
        "        source_cleaned.append(preprocess(source))\n",
        "        target_cleaned.append(preprocess(target))\n",
        "    return source_cleaned, target_cleaned\n"
      ],
      "execution_count": 2,
      "outputs": []
    },
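    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of what the cleaning step produces, here is a minimal standalone sketch. The name `simple_preprocess` is hypothetical and is not used elsewhere in this notebook; it folds the lowercase, accent-stripping, punctuation-spacing, and tagging steps into one short function and skips the digit filtering done by `preprocess`.\n",
        "\n",
        "```python\n",
        "import re\n",
        "from unicodedata import normalize\n",
        "\n",
        "\n",
        "def simple_preprocess(sentence: str) -> str:\n",
        "    # lowercase, decompose accented characters, then drop anything non-ASCII\n",
        "    text = normalize(\"NFD\", sentence.lower()).encode(\"ascii\", \"ignore\").decode(\"utf-8\")\n",
        "    # surround the punctuation marks kept by the notebook with spaces\n",
        "    text = re.sub(r\"([?.!,])\", r\" \\1 \", text)\n",
        "    # collapse repeated whitespace\n",
        "    text = \" \".join(text.split())\n",
        "    # wrap the sentence in the start and end markers\n",
        "    return f\"<start> {text} <end>\"\n",
        "\n",
        "\n",
        "print(simple_preprocess(\"Tom ist nicht beschäftigt.\"))\n",
        "# <start> tom ist nicht beschaftigt . <end>\n",
        "```\n",
        "\n",
        "Note that the umlaut is reduced to a plain `a`, the same lossy behavior that shows up in the preprocessed pairs printed by this notebook."
      ]
    },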
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nDIELt9RH-w",
        "colab_type": "text"
      },
      "source": [
        "## Create Dataset\n",
        "\n",
        "- limit number of examples\n",
        "- load dataset into pairs `[('Be nice.', 'Seien Sie nett!'), ('Beat it.', 'Geh weg!'), ...]`\n",
        "- preprocessing dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GMxdlVU1X8yI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 119
        },
        "outputId": "f4977f48-dbe9-4323-ec2a-a9b0cf8b1895"
      },
      "source": [
        "NUM_EXAMPLES = 10000 # Limit dataset size\n",
        "\n",
        "# load from .txt\n",
        "filename = 'deu.txt' #change filename if necessary\n",
        "dataset = load_dataset(filename)\n",
        "# get pairs limited into 1000\n",
        "pairs = to_pairs(dataset, limit=NUM_EXAMPLES)\n",
        "print(f\"Dataset size: {len(pairs)}\")\n",
        "raw_data_en, raw_data_ge = dataset_preprocess(pairs)\n",
        "\n",
        "# show last 5 pairs\n",
        "for pair in zip(raw_data_en[-5:],raw_data_ge[-5:]):\n",
        "    print(pair)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Dataset size: 10000\n",
            "(\"<start> tom's hungover . <end>\", '<start> tom ist verkatert . <end>')\n",
            "(\"<start> tom's in there . <end>\", '<start> tom ist da drinnen . <end>')\n",
            "(\"<start> tom's innocent . <end>\", '<start> tom ist unschuldig . <end>')\n",
            "(\"<start> tom's laughing . <end>\", '<start> tom lacht . <end>')\n",
            "(\"<start> tom's not busy . <end>\", '<start> tom ist nicht beschaftigt . <end>')\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cfb66QxWYr6A",
        "colab_type": "text"
      },
      "source": [
        "## Tokenization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3oq60MBPSanQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "en_tokenizer.fit_on_texts(raw_data_en)\n",
        "\n",
        "data_en = en_tokenizer.texts_to_sequences(raw_data_en)\n",
        "data_en = tf.keras.preprocessing.sequence.pad_sequences(data_en,padding='post')\n",
        "\n",
        "ge_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "ge_tokenizer.fit_on_texts(raw_data_ge)\n",
        "\n",
        "data_ge = ge_tokenizer.texts_to_sequences(raw_data_ge)\n",
        "data_ge = tf.keras.preprocessing.sequence.pad_sequences(data_ge,padding='post')"
      ],
      "execution_count": 4,
      "outputs": []
    },
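    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Keras `Tokenizer` and `pad_sequences` calls above map each word to an integer and pad shorter sentences with zeros at the end. The following pure-Python sketch mimics that idea under simplifying assumptions; `build_word_index` and `encode_and_pad` are hypothetical helper names, not part of Keras, and tie-breaking between equally frequent words may differ from the library.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "\n",
        "def build_word_index(texts):\n",
        "    # count every whitespace-separated token across all sentences\n",
        "    counts = Counter(word for text in texts for word in text.split())\n",
        "    # the most frequent word gets index 1; index 0 is left free for padding\n",
        "    return {word: i for i, (word, _) in enumerate(counts.most_common(), start=1)}\n",
        "\n",
        "\n",
        "def encode_and_pad(texts, word_index):\n",
        "    # replace each word by its integer id\n",
        "    sequences = [[word_index[w] for w in text.split()] for text in texts]\n",
        "    longest = max(len(seq) for seq in sequences)\n",
        "    # append zeros so that every sequence has the same length (post-padding)\n",
        "    return [seq + [0] * (longest - len(seq)) for seq in sequences]\n",
        "\n",
        "\n",
        "texts = [\"<start> go . <end>\", \"<start> tom lacht . <end>\"]\n",
        "word_index = build_word_index(texts)\n",
        "padded = encode_and_pad(texts, word_index)\n",
        "vocab_size = len(word_index) + 1  # +1 accounts for the padding id 0\n",
        "print(padded)\n",
        "```\n",
        "\n",
        "Index 0 never appears in `word_index` because it is reserved for padding, which is why the notebook adds 1 to `len(word_index)` when computing the vocabulary sizes."
      ]
    },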
    {
      "cell_type": "code",
      "metadata": {
        "id": "XH5oSRNeSc1s",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def max_len(tensor):\n",
        "    #print( np.argmax([len(t) for t in tensor]))\n",
        "    return max( len(t) for t in tensor)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KdM37lNBGXAj",
        "colab_type": "text"
      },
      "source": [
        "## Model Parameters"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EfiBUJM2Et6C",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_train,  X_test, Y_train, Y_test = train_test_split(data_en,data_ge,test_size=0.2)\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = len(X_train)\n",
        "steps_per_epoch = BUFFER_SIZE//BATCH_SIZE\n",
        "embedding_dims = 256\n",
        "rnn_units = 1024\n",
        "dense_units = 1024\n",
        "Dtype = tf.float32   #used to initialize DecoderCell Zero state"
      ],
      "execution_count": 6,
      "outputs": []
    },
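    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see where the number of batches per epoch comes from, the arithmetic can be worked out by hand. This sketch assumes the 10,000-example limit set earlier and the 20% test split used above; the variable names are for illustration only.\n",
        "\n",
        "```python\n",
        "num_examples = 10000   # value of NUM_EXAMPLES in this notebook\n",
        "test_percent = 20      # corresponds to test_size=0.2\n",
        "batch_size = 64        # value of BATCH_SIZE\n",
        "\n",
        "# examples left for training after the split\n",
        "train_examples = num_examples * (100 - test_percent) // 100\n",
        "# full batches per epoch; the incomplete last batch is dropped\n",
        "steps_per_epoch = train_examples // batch_size\n",
        "\n",
        "print(train_examples, steps_per_epoch)\n",
        "```\n",
        "\n",
        "With these settings each epoch runs 125 batches, which matches the batch counter in the training log."
      ]
    },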
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ff_jQHLhGqJU",
        "colab_type": "text"
      },
      "source": [
        "## Dataset Prepration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b__1hPHVFALO",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "outputId": "88d35286-184c-44e7-a16b-5559f22e2eb1"
      },
      "source": [
        "Tx = max_len(data_en)\n",
        "Ty = max_len(data_ge)  \n",
        "\n",
        "input_vocab_size = len(en_tokenizer.word_index)+1  \n",
        "output_vocab_size = len(ge_tokenizer.word_index)+ 1\n",
        "dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "example_X, example_Y = next(iter(dataset))\n",
        "print(example_X.shape) \n",
        "print(example_Y.shape) "
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(64, 9)\n",
            "(64, 13)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UQRgJcYgapqE",
        "colab_type": "text"
      },
      "source": [
        "## Defining NMT Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sGdakRtjaokF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#ENCODER\n",
        "class EncoderNetwork(tf.keras.Model):\n",
        "    def __init__(self,input_vocab_size,embedding_dims, rnn_units ):\n",
        "        super().__init__()\n",
        "        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,\n",
        "                                                           output_dim=embedding_dims)\n",
        "        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,return_sequences=True, \n",
        "                                                     return_state=True )\n",
        "    \n",
        "#DECODER\n",
        "class DecoderNetwork(tf.keras.Model):\n",
        "    def __init__(self,output_vocab_size, embedding_dims, rnn_units):\n",
        "        super().__init__()\n",
        "        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,\n",
        "                                                           output_dim=embedding_dims) \n",
        "        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)\n",
        "        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)\n",
        "        # Sampler\n",
        "        self.sampler = tfa.seq2seq.sampler.TrainingSampler()\n",
        "        # Create attention mechanism with memory = None\n",
        "        self.attention_mechanism = self.build_attention_mechanism(dense_units,None,BATCH_SIZE*[Tx])\n",
        "        self.rnn_cell =  self.build_rnn_cell(BATCH_SIZE)\n",
        "        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler= self.sampler,\n",
        "                                                output_layer=self.dense_layer)\n",
        "\n",
        "    def build_attention_mechanism(self, units,memory, memory_sequence_length):\n",
        "        return tfa.seq2seq.LuongAttention(units, memory = memory, \n",
        "                                          memory_sequence_length=memory_sequence_length)\n",
        "        #return tfa.seq2seq.BahdanauAttention(units, memory = memory, memory_sequence_length=memory_sequence_length)\n",
        "\n",
        "    # wrap decodernn cell  \n",
        "    def build_rnn_cell(self, batch_size ):\n",
        "        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,\n",
        "                                                attention_layer_size=dense_units)\n",
        "        return rnn_cell\n",
        "    \n",
        "    def build_decoder_initial_state(self, batch_size, encoder_state,Dtype):\n",
        "        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size = batch_size, \n",
        "                                                                dtype = Dtype)\n",
        "        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) \n",
        "        return decoder_initial_state\n",
        "\n",
        "encoderNetwork = EncoderNetwork(input_vocab_size,embedding_dims, rnn_units)\n",
        "decoderNetwork = DecoderNetwork(output_vocab_size,embedding_dims, rnn_units)\n",
        "optimizer = tf.keras.optimizers.Adam()\n"
      ],
      "execution_count": 8,
      "outputs": []
    },
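    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, Luong-style attention scores each encoder output against the current decoder state with a dot product, turns the scores into weights with a softmax, and returns the weighted sum of encoder outputs as a context vector. The pure-Python sketch below illustrates only that idea; `luong_attention` is a hypothetical name, and the sketch leaves out the learned projection layers and sequence-length masking that `tfa.seq2seq.LuongAttention` and `AttentionWrapper` add.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def luong_attention(query, memory):\n",
        "    # dot-product score between the decoder state and each encoder output\n",
        "    scores = [sum(q * m for q, m in zip(query, step)) for step in memory]\n",
        "    # numerically stable softmax over the encoder time steps\n",
        "    peak = max(scores)\n",
        "    exps = [math.exp(s - peak) for s in scores]\n",
        "    total = sum(exps)\n",
        "    weights = [e / total for e in exps]\n",
        "    # context vector: attention-weighted sum of the encoder outputs\n",
        "    context = [\n",
        "        sum(w * step[i] for w, step in zip(weights, memory))\n",
        "        for i in range(len(query))\n",
        "    ]\n",
        "    return weights, context\n",
        "\n",
        "\n",
        "# three encoder time steps with two features each (made-up values)\n",
        "memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
        "query = [2.0, 0.0]\n",
        "weights, context = luong_attention(query, memory)\n",
        "print(weights, context)\n",
        "```\n",
        "\n",
        "The first and third encoder steps align equally well with the query, so they receive equal weight, while the second step receives less."
      ]
    },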
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NPwcfddTa0oB",
        "colab_type": "text"
      },
      "source": [
        "## Initializing Training functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x1BEqVyra2jW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def loss_function(y_pred, y):\n",
        "   \n",
        "    #shape of y [batch_size, ty]\n",
        "    #shape of y_pred [batch_size, Ty, output_vocab_size] \n",
        "    sparsecategoricalcrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,\n",
        "                                                                                  reduction='none')\n",
        "    loss = sparsecategoricalcrossentropy(y_true=y, y_pred=y_pred)\n",
        "    mask = tf.logical_not(tf.math.equal(y,0))   #output 0 for y=0 else output 1\n",
        "    mask = tf.cast(mask, dtype=loss.dtype)\n",
        "    loss = mask* loss\n",
        "    loss = tf.reduce_mean(loss)\n",
        "    return loss\n",
        "\n",
        "def train_step(input_batch, output_batch,encoder_initial_cell_state):\n",
        "    #initialize loss = 0\n",
        "    loss = 0\n",
        "    with tf.GradientTape() as tape:\n",
        "        encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)\n",
        "        a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp, \n",
        "                                                        initial_state =encoder_initial_cell_state)\n",
        "\n",
        "        #[last step activations,last memory_state] of encoder passed as input to decoder Network\n",
        "        \n",
        "         \n",
        "        # Prepare correct Decoder input & output sequence data\n",
        "        decoder_input = output_batch[:,:-1] # ignore <end>\n",
        "        #compare logits with timestepped +1 version of decoder_input\n",
        "        decoder_output = output_batch[:,1:] #ignore <start>\n",
        "\n",
        "\n",
        "        # Decoder Embeddings\n",
        "        decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)\n",
        "\n",
        "        #Setting up decoder memory from encoder output and Zero State for AttentionWrapperState\n",
        "        decoderNetwork.attention_mechanism.setup_memory(a)\n",
        "        decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,\n",
        "                                                                           encoder_state=[a_tx, c_tx],\n",
        "                                                                           Dtype=tf.float32)\n",
        "        \n",
        "        #BasicDecoderOutput        \n",
        "        outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,initial_state=decoder_initial_state,\n",
        "                                               sequence_length=BATCH_SIZE*[Ty-1])\n",
        "\n",
        "        logits = outputs.rnn_output\n",
        "        #Calculate loss\n",
        "\n",
        "        loss = loss_function(logits, decoder_output)\n",
        "\n",
        "    #Returns the list of all layer variables / weights.\n",
        "    variables = encoderNetwork.trainable_variables + decoderNetwork.trainable_variables  \n",
        "    # differentiate loss wrt variables\n",
        "    gradients = tape.gradient(loss, variables)\n",
        "\n",
        "    #grads_and_vars – List of(gradient, variable) pairs.\n",
        "    grads_and_vars = zip(gradients,variables)\n",
        "    optimizer.apply_gradients(grads_and_vars)\n",
        "    return loss"
      ],
      "execution_count": 10,
      "outputs": []
    },
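    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The slicing and masking in `train_step` and `loss_function` can be easier to follow on a single toy sequence. The sketch below uses plain Python lists, made-up token ids, and made-up per-step loss values; none of these numbers come from the notebook.\n",
        "\n",
        "```python\n",
        "# hypothetical token ids for \"<start> tom lacht . <end>\" plus two padding zeros\n",
        "target = [1, 5, 6, 2, 3, 0, 0]\n",
        "\n",
        "decoder_input = target[:-1]   # what the decoder is fed at each step\n",
        "decoder_output = target[1:]   # what it is asked to predict at each step\n",
        "\n",
        "# made-up per-step losses, one per position of decoder_output\n",
        "per_step_loss = [0.9, 0.7, 0.4, 0.2, 1.5, 1.5]\n",
        "\n",
        "# 1.0 where the target is a real token, 0.0 where it is padding (id 0)\n",
        "mask = [0.0 if token == 0 else 1.0 for token in decoder_output]\n",
        "masked = [m * l for m, l in zip(mask, per_step_loss)]\n",
        "\n",
        "# mean over every position, padded ones included, mirroring tf.reduce_mean in loss_function\n",
        "mean_loss = sum(masked) / len(masked)\n",
        "print(decoder_input, decoder_output, mean_loss)\n",
        "```\n",
        "\n",
        "The padded positions contribute zero loss, so the model is not penalized for whatever it predicts after the end of a sentence."
      ]
    },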
    {
      "cell_type": "code",
      "metadata": {
        "id": "71Lkdx6GFb3A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#RNN LSTM hidden and memory state initializer\n",
        "def initialize_initial_state():\n",
        "        return [tf.zeros((BATCH_SIZE, rnn_units)), tf.zeros((BATCH_SIZE, rnn_units))]"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v5uzLcu2bNX3",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PvfD2SknWrt6",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "0a427bb7-8184-4076-97ca-f638116ca52b"
      },
      "source": [
        "epochs = 15\n",
        "for i in range(1, epochs+1):\n",
        "\n",
        "    encoder_initial_cell_state = initialize_initial_state()\n",
        "    total_loss = 0.0\n",
        "\n",
        "    for ( batch , (input_batch, output_batch)) in enumerate(dataset.take(steps_per_epoch)):\n",
        "        batch_loss = train_step(input_batch, output_batch, encoder_initial_cell_state)\n",
        "        total_loss += batch_loss\n",
        "        if (batch+1)%5 == 0:\n",
        "            print(\"total loss: {} epoch {} batch {} \".format(batch_loss.numpy(), i, batch+1))"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "total loss: 4.11566686630249 epoch 1 batch 5 \n",
            "total loss: 2.993711233139038 epoch 1 batch 10 \n",
            "total loss: 2.456167459487915 epoch 1 batch 15 \n",
            "total loss: 2.1430583000183105 epoch 1 batch 20 \n",
            "total loss: 2.202500820159912 epoch 1 batch 25 \n",
            "total loss: 2.0447075366973877 epoch 1 batch 30 \n",
            "total loss: 1.943502426147461 epoch 1 batch 35 \n",
            "total loss: 1.8647733926773071 epoch 1 batch 40 \n",
            "total loss: 1.887935757637024 epoch 1 batch 45 \n",
            "total loss: 2.0422816276550293 epoch 1 batch 50 \n",
            "total loss: 1.7727909088134766 epoch 1 batch 55 \n",
            "total loss: 1.64265775680542 epoch 1 batch 60 \n",
            "total loss: 1.708620548248291 epoch 1 batch 65 \n",
            "total loss: 1.663000464439392 epoch 1 batch 70 \n",
            "total loss: 1.733208179473877 epoch 1 batch 75 \n",
            "total loss: 1.6179828643798828 epoch 1 batch 80 \n",
            "total loss: 1.6496108770370483 epoch 1 batch 85 \n",
            "total loss: 1.7499841451644897 epoch 1 batch 90 \n",
            "total loss: 1.6253910064697266 epoch 1 batch 95 \n",
            "total loss: 1.6513166427612305 epoch 1 batch 100 \n",
            "total loss: 1.6768431663513184 epoch 1 batch 105 \n",
            "total loss: 1.5870885848999023 epoch 1 batch 110 \n",
            "total loss: 1.5872650146484375 epoch 1 batch 115 \n",
            "total loss: 1.6291579008102417 epoch 1 batch 120 \n",
            "total loss: 1.5899494886398315 epoch 1 batch 125 \n",
            "total loss: 1.4799423217773438 epoch 2 batch 5 \n",
            "total loss: 1.5262809991836548 epoch 2 batch 10 \n",
            "total loss: 1.5344295501708984 epoch 2 batch 15 \n",
            "total loss: 1.4018179178237915 epoch 2 batch 20 \n",
            "total loss: 1.2517988681793213 epoch 2 batch 25 \n",
            "total loss: 1.3529373407363892 epoch 2 batch 30 \n",
            "total loss: 1.3586145639419556 epoch 2 batch 35 \n",
            "total loss: 1.445724606513977 epoch 2 batch 40 \n",
            "total loss: 1.3398317098617554 epoch 2 batch 45 \n",
            "total loss: 1.3680673837661743 epoch 2 batch 50 \n",
            "total loss: 1.2311376333236694 epoch 2 batch 55 \n",
            "total loss: 1.4052680730819702 epoch 2 batch 60 \n",
            "total loss: 1.333901286125183 epoch 2 batch 65 \n",
            "total loss: 1.3520702123641968 epoch 2 batch 70 \n",
            "total loss: 1.3728466033935547 epoch 2 batch 75 \n",
            "total loss: 1.2714239358901978 epoch 2 batch 80 \n",
            "total loss: 1.2564586400985718 epoch 2 batch 85 \n",
            "total loss: 1.3137339353561401 epoch 2 batch 90 \n",
            "total loss: 1.2975138425827026 epoch 2 batch 95 \n",
            "total loss: 1.387586236000061 epoch 2 batch 100 \n",
            "total loss: 1.3724294900894165 epoch 2 batch 105 \n",
            "total loss: 1.2119255065917969 epoch 2 batch 110 \n",
            "total loss: 1.3122206926345825 epoch 2 batch 115 \n",
            "total loss: 1.2198586463928223 epoch 2 batch 120 \n",
            "total loss: 1.3899301290512085 epoch 2 batch 125 \n",
            "total loss: 1.165706753730774 epoch 3 batch 5 \n",
            "total loss: 1.1252425909042358 epoch 3 batch 10 \n",
            "total loss: 1.238714575767517 epoch 3 batch 15 \n",
            "total loss: 1.1738801002502441 epoch 3 batch 20 \n",
            "total loss: 1.0969573259353638 epoch 3 batch 25 \n",
            "total loss: 1.333953857421875 epoch 3 batch 30 \n",
            "total loss: 1.150842547416687 epoch 3 batch 35 \n",
            "total loss: 1.188685417175293 epoch 3 batch 40 \n",
            "total loss: 1.1465986967086792 epoch 3 batch 45 \n",
            "total loss: 1.192483901977539 epoch 3 batch 50 \n",
            "total loss: 1.0624631643295288 epoch 3 batch 55 \n",
            "total loss: 1.0510236024856567 epoch 3 batch 60 \n",
            "total loss: 1.2240933179855347 epoch 3 batch 65 \n",
            "total loss: 1.1826471090316772 epoch 3 batch 70 \n",
            "total loss: 1.1488090753555298 epoch 3 batch 75 \n",
            "total loss: 1.0666199922561646 epoch 3 batch 80 \n",
            "total loss: 1.099870204925537 epoch 3 batch 85 \n",
            "total loss: 1.2090061902999878 epoch 3 batch 90 \n",
            "total loss: 1.0520261526107788 epoch 3 batch 95 \n",
            "total loss: 1.1468778848648071 epoch 3 batch 100 \n",
            "total loss: 1.1265020370483398 epoch 3 batch 105 \n",
            "total loss: 1.1606225967407227 epoch 3 batch 110 \n",
            "total loss: 1.0110392570495605 epoch 3 batch 115 \n",
            "total loss: 1.0859214067459106 epoch 3 batch 120 \n",
            "total loss: 1.0578597784042358 epoch 3 batch 125 \n",
            "total loss: 0.9532763957977295 epoch 4 batch 5 \n",
            "total loss: 0.9986910223960876 epoch 4 batch 10 \n",
            "total loss: 0.956193208694458 epoch 4 batch 15 \n",
            "total loss: 0.9690749049186707 epoch 4 batch 20 \n",
            "total loss: 1.101245403289795 epoch 4 batch 25 \n",
            "total loss: 0.9993131160736084 epoch 4 batch 30 \n",
            "total loss: 0.9002986550331116 epoch 4 batch 35 \n",
            "total loss: 0.9263710975646973 epoch 4 batch 40 \n",
            "total loss: 0.9197303652763367 epoch 4 batch 45 \n",
            "total loss: 0.928395688533783 epoch 4 batch 50 \n",
            "total loss: 1.0114047527313232 epoch 4 batch 55 \n",
            "total loss: 1.083633542060852 epoch 4 batch 60 \n",
            "total loss: 0.9597204327583313 epoch 4 batch 65 \n",
            "total loss: 0.948369562625885 epoch 4 batch 70 \n",
            "total loss: 0.9748582243919373 epoch 4 batch 75 \n",
            "total loss: 1.1318320035934448 epoch 4 batch 80 \n",
            "total loss: 0.9337785243988037 epoch 4 batch 85 \n",
            "total loss: 1.066165566444397 epoch 4 batch 90 \n",
            "total loss: 0.896867573261261 epoch 4 batch 95 \n",
            "total loss: 0.8608654141426086 epoch 4 batch 100 \n",
            "total loss: 1.0241042375564575 epoch 4 batch 105 \n",
            "total loss: 0.9655657410621643 epoch 4 batch 110 \n",
            "total loss: 0.9644956588745117 epoch 4 batch 115 \n",
            "total loss: 0.9764884114265442 epoch 4 batch 120 \n",
            "total loss: 0.9435749650001526 epoch 4 batch 125 \n",
            "total loss: 0.8453261852264404 epoch 5 batch 5 \n",
            "total loss: 0.8299359679222107 epoch 5 batch 10 \n",
            "total loss: 0.7513504028320312 epoch 5 batch 15 \n",
            "total loss: 0.9070504307746887 epoch 5 batch 20 \n",
            "total loss: 0.8284425139427185 epoch 5 batch 25 \n",
            "total loss: 0.7402708530426025 epoch 5 batch 30 \n",
            "total loss: 0.8092263340950012 epoch 5 batch 35 \n",
            "total loss: 0.8729425072669983 epoch 5 batch 40 \n",
            "total loss: 0.8656869530677795 epoch 5 batch 45 \n",
            "total loss: 0.7958595752716064 epoch 5 batch 50 \n",
            "total loss: 0.8578026294708252 epoch 5 batch 55 \n",
            "total loss: 0.778644859790802 epoch 5 batch 60 \n",
            "total loss: 0.7277960777282715 epoch 5 batch 65 \n",
            "total loss: 0.8738289475440979 epoch 5 batch 70 \n",
            "total loss: 0.6856063008308411 epoch 5 batch 75 \n",
            "total loss: 0.8267806172370911 epoch 5 batch 80 \n",
            "total loss: 0.946643054485321 epoch 5 batch 85 \n",
            "total loss: 0.9214975237846375 epoch 5 batch 90 \n",
            "total loss: 0.796623706817627 epoch 5 batch 95 \n",
            "total loss: 0.8234322667121887 epoch 5 batch 100 \n",
            "total loss: 0.8310582041740417 epoch 5 batch 105 \n",
            "total loss: 0.7187271118164062 epoch 5 batch 110 \n",
            "total loss: 0.881775438785553 epoch 5 batch 115 \n",
            "total loss: 0.8496475219726562 epoch 5 batch 120 \n",
            "total loss: 0.7749930024147034 epoch 5 batch 125 \n",
            "total loss: 0.6057237386703491 epoch 6 batch 5 \n",
            "total loss: 0.5688410401344299 epoch 6 batch 10 \n",
            "total loss: 0.6365471482276917 epoch 6 batch 15 \n",
            "total loss: 0.6626251935958862 epoch 6 batch 20 \n",
            "total loss: 0.6636946797370911 epoch 6 batch 25 \n",
            "total loss: 0.6313133835792542 epoch 6 batch 30 \n",
            "total loss: 0.5917147397994995 epoch 6 batch 35 \n",
            "total loss: 0.6965726017951965 epoch 6 batch 40 \n",
            "total loss: 0.6281453371047974 epoch 6 batch 45 \n",
            "total loss: 0.6475895047187805 epoch 6 batch 50 \n",
            "total loss: 0.7765358090400696 epoch 6 batch 55 \n",
            "total loss: 0.5973318219184875 epoch 6 batch 60 \n",
            "total loss: 0.713416576385498 epoch 6 batch 65 \n",
            "total loss: 0.7173630595207214 epoch 6 batch 70 \n",
            "total loss: 0.7002382874488831 epoch 6 batch 75 \n",
            "total loss: 0.6431768536567688 epoch 6 batch 80 \n",
            "total loss: 0.6381948590278625 epoch 6 batch 85 \n",
            "total loss: 0.7046375870704651 epoch 6 batch 90 \n",
            "total loss: 0.6564927697181702 epoch 6 batch 95 \n",
            "total loss: 0.7156146168708801 epoch 6 batch 100 \n",
            "total loss: 0.7078973650932312 epoch 6 batch 105 \n",
            "total loss: 0.6482166647911072 epoch 6 batch 110 \n",
            "total loss: 0.5653694868087769 epoch 6 batch 115 \n",
            "total loss: 0.768178403377533 epoch 6 batch 120 \n",
            "total loss: 0.6993356347084045 epoch 6 batch 125 \n",
            "total loss: 0.4355561435222626 epoch 7 batch 5 \n",
            "total loss: 0.5312787294387817 epoch 7 batch 10 \n",
            "total loss: 0.5179179906845093 epoch 7 batch 15 \n",
            "total loss: 0.5177888870239258 epoch 7 batch 20 \n",
            "total loss: 0.5274668335914612 epoch 7 batch 25 \n",
            "total loss: 0.4485582113265991 epoch 7 batch 30 \n",
            "total loss: 0.5205077528953552 epoch 7 batch 35 \n",
            "total loss: 0.6028087735176086 epoch 7 batch 40 \n",
            "total loss: 0.4433538615703583 epoch 7 batch 45 \n",
            "total loss: 0.5281521677970886 epoch 7 batch 50 \n",
            "total loss: 0.5123710036277771 epoch 7 batch 55 \n",
            "total loss: 0.4892776906490326 epoch 7 batch 60 \n",
            "total loss: 0.5777449011802673 epoch 7 batch 65 \n",
            "total loss: 0.5938393473625183 epoch 7 batch 70 \n",
            "total loss: 0.5447298884391785 epoch 7 batch 75 \n",
            "total loss: 0.5399925112724304 epoch 7 batch 80 \n",
            "total loss: 0.549943745136261 epoch 7 batch 85 \n",
            "total loss: 0.5606051683425903 epoch 7 batch 90 \n",
            "total loss: 0.6317020058631897 epoch 7 batch 95 \n",
            "total loss: 0.5499157309532166 epoch 7 batch 100 \n",
            "total loss: 0.5369137525558472 epoch 7 batch 105 \n",
            "total loss: 0.6119964718818665 epoch 7 batch 110 \n",
            "total loss: 0.6122032403945923 epoch 7 batch 115 \n",
            "total loss: 0.6180634498596191 epoch 7 batch 120 \n",
            "total loss: 0.5060015320777893 epoch 7 batch 125 \n",
            "total loss: 0.4102749526500702 epoch 8 batch 5 \n",
            "total loss: 0.4113573729991913 epoch 8 batch 10 \n",
            "total loss: 0.34586894512176514 epoch 8 batch 15 \n",
            "total loss: 0.4162067174911499 epoch 8 batch 20 \n",
            "total loss: 0.4488414227962494 epoch 8 batch 25 \n",
            "total loss: 0.47596967220306396 epoch 8 batch 30 \n",
            "total loss: 0.43868470191955566 epoch 8 batch 35 \n",
            "total loss: 0.4669533669948578 epoch 8 batch 40 \n",
            "total loss: 0.4095423221588135 epoch 8 batch 45 \n",
            "total loss: 0.4171658754348755 epoch 8 batch 50 \n",
            "total loss: 0.41935643553733826 epoch 8 batch 55 \n",
            "total loss: 0.42487478256225586 epoch 8 batch 60 \n",
            "total loss: 0.5020427107810974 epoch 8 batch 65 \n",
            "total loss: 0.46865570545196533 epoch 8 batch 70 \n",
            "total loss: 0.48575273156166077 epoch 8 batch 75 \n",
            "total loss: 0.402313232421875 epoch 8 batch 80 \n",
            "total loss: 0.5250392556190491 epoch 8 batch 85 \n",
            "total loss: 0.5152303576469421 epoch 8 batch 90 \n",
            "total loss: 0.4697692394256592 epoch 8 batch 95 \n",
            "total loss: 0.4108094274997711 epoch 8 batch 100 \n",
            "total loss: 0.4215029776096344 epoch 8 batch 105 \n",
            "total loss: 0.43752169609069824 epoch 8 batch 110 \n",
            "total loss: 0.45470383763313293 epoch 8 batch 115 \n",
            "total loss: 0.5394885540008545 epoch 8 batch 120 \n",
            "total loss: 0.46421656012535095 epoch 8 batch 125 \n",
            "total loss: 0.38278815150260925 epoch 9 batch 5 \n",
            "total loss: 0.3325278162956238 epoch 9 batch 10 \n",
            "total loss: 0.25561612844467163 epoch 9 batch 15 \n",
            "total loss: 0.39196979999542236 epoch 9 batch 20 \n",
            "total loss: 0.3144271671772003 epoch 9 batch 25 \n",
            "total loss: 0.3374980390071869 epoch 9 batch 30 \n",
            "total loss: 0.3220641613006592 epoch 9 batch 35 \n",
            "total loss: 0.28498175740242004 epoch 9 batch 40 \n",
            "total loss: 0.34717854857444763 epoch 9 batch 45 \n",
            "total loss: 0.27360835671424866 epoch 9 batch 50 \n",
            "total loss: 0.34681805968284607 epoch 9 batch 55 \n",
            "total loss: 0.34650543332099915 epoch 9 batch 60 \n",
            "total loss: 0.377156525850296 epoch 9 batch 65 \n",
            "total loss: 0.3942091464996338 epoch 9 batch 70 \n",
            "total loss: 0.40023937821388245 epoch 9 batch 75 \n",
            "total loss: 0.3928321301937103 epoch 9 batch 80 \n",
            "total loss: 0.3811839818954468 epoch 9 batch 85 \n",
            "total loss: 0.3996661901473999 epoch 9 batch 90 \n",
            "total loss: 0.4434957504272461 epoch 9 batch 95 \n",
            "total loss: 0.36710819602012634 epoch 9 batch 100 \n",
            "total loss: 0.4244243800640106 epoch 9 batch 105 \n",
            "total loss: 0.39385613799095154 epoch 9 batch 110 \n",
            "total loss: 0.40314915776252747 epoch 9 batch 115 \n",
            "total loss: 0.38281798362731934 epoch 9 batch 120 \n",
            "total loss: 0.34032365679740906 epoch 9 batch 125 \n",
            "total loss: 0.25806427001953125 epoch 10 batch 5 \n",
            "total loss: 0.2550051212310791 epoch 10 batch 10 \n",
            "total loss: 0.23162484169006348 epoch 10 batch 15 \n",
            "total loss: 0.26205796003341675 epoch 10 batch 20 \n",
            "total loss: 0.2918882369995117 epoch 10 batch 25 \n",
            "total loss: 0.28667184710502625 epoch 10 batch 30 \n",
            "total loss: 0.30746373534202576 epoch 10 batch 35 \n",
            "total loss: 0.24943065643310547 epoch 10 batch 40 \n",
            "total loss: 0.24033032357692719 epoch 10 batch 45 \n",
            "total loss: 0.29537105560302734 epoch 10 batch 50 \n",
            "total loss: 0.3474333584308624 epoch 10 batch 55 \n",
            "total loss: 0.31821370124816895 epoch 10 batch 60 \n",
            "total loss: 0.35506772994995117 epoch 10 batch 65 \n",
            "total loss: 0.40117380023002625 epoch 10 batch 70 \n",
            "total loss: 0.2801777422428131 epoch 10 batch 75 \n",
            "total loss: 0.26276424527168274 epoch 10 batch 80 \n",
            "total loss: 0.33141613006591797 epoch 10 batch 85 \n",
            "total loss: 0.2891913056373596 epoch 10 batch 90 \n",
            "total loss: 0.34682735800743103 epoch 10 batch 95 \n",
            "total loss: 0.39360567927360535 epoch 10 batch 100 \n",
            "total loss: 0.40213945508003235 epoch 10 batch 105 \n",
            "total loss: 0.2949744462966919 epoch 10 batch 110 \n",
            "total loss: 0.27941974997520447 epoch 10 batch 115 \n",
            "total loss: 0.28911301493644714 epoch 10 batch 120 \n",
            "total loss: 0.3066214621067047 epoch 10 batch 125 \n",
            "total loss: 0.24125127494335175 epoch 11 batch 5 \n",
            "total loss: 0.20186877250671387 epoch 11 batch 10 \n",
            "total loss: 0.2145632952451706 epoch 11 batch 15 \n",
            "total loss: 0.23457588255405426 epoch 11 batch 20 \n",
            "total loss: 0.2408994436264038 epoch 11 batch 25 \n",
            "total loss: 0.1797456294298172 epoch 11 batch 30 \n",
            "total loss: 0.19768892228603363 epoch 11 batch 35 \n",
            "total loss: 0.21545785665512085 epoch 11 batch 40 \n",
            "total loss: 0.23571373522281647 epoch 11 batch 45 \n",
            "total loss: 0.25327250361442566 epoch 11 batch 50 \n",
            "total loss: 0.2649385631084442 epoch 11 batch 55 \n",
            "total loss: 0.30291682481765747 epoch 11 batch 60 \n",
            "total loss: 0.2986145317554474 epoch 11 batch 65 \n",
            "total loss: 0.20132605731487274 epoch 11 batch 70 \n",
            "total loss: 0.24036192893981934 epoch 11 batch 75 \n",
            "total loss: 0.29774945974349976 epoch 11 batch 80 \n",
            "total loss: 0.24990446865558624 epoch 11 batch 85 \n",
            "total loss: 0.27169445157051086 epoch 11 batch 90 \n",
            "total loss: 0.2602415978908539 epoch 11 batch 95 \n",
            "total loss: 0.26800140738487244 epoch 11 batch 100 \n",
            "total loss: 0.27735427021980286 epoch 11 batch 105 \n",
            "total loss: 0.26872459053993225 epoch 11 batch 110 \n",
            "total loss: 0.27796170115470886 epoch 11 batch 115 \n",
            "total loss: 0.3037005364894867 epoch 11 batch 120 \n",
            "total loss: 0.3586186468601227 epoch 11 batch 125 \n",
            "total loss: 0.1714320331811905 epoch 12 batch 5 \n",
            "total loss: 0.17246638238430023 epoch 12 batch 10 \n",
            "total loss: 0.2304478883743286 epoch 12 batch 15 \n",
            "total loss: 0.20280466973781586 epoch 12 batch 20 \n",
            "total loss: 0.1980326622724533 epoch 12 batch 25 \n",
            "total loss: 0.24768061935901642 epoch 12 batch 30 \n",
            "total loss: 0.17778398096561432 epoch 12 batch 35 \n",
            "total loss: 0.2015562802553177 epoch 12 batch 40 \n",
            "total loss: 0.1770702451467514 epoch 12 batch 45 \n",
            "total loss: 0.2334766387939453 epoch 12 batch 50 \n",
            "total loss: 0.20495925843715668 epoch 12 batch 55 \n",
            "total loss: 0.21376878023147583 epoch 12 batch 60 \n",
            "total loss: 0.24144266545772552 epoch 12 batch 65 \n",
            "total loss: 0.2306946963071823 epoch 12 batch 70 \n",
            "total loss: 0.23844711482524872 epoch 12 batch 75 \n",
            "total loss: 0.24324734508991241 epoch 12 batch 80 \n",
            "total loss: 0.1984959989786148 epoch 12 batch 85 \n",
            "total loss: 0.2658829689025879 epoch 12 batch 90 \n",
            "total loss: 0.24130244553089142 epoch 12 batch 95 \n",
            "total loss: 0.23028753697872162 epoch 12 batch 100 \n",
            "total loss: 0.27955183386802673 epoch 12 batch 105 \n",
            "total loss: 0.269803524017334 epoch 12 batch 110 \n",
            "total loss: 0.24687449634075165 epoch 12 batch 115 \n",
            "total loss: 0.2637614905834198 epoch 12 batch 120 \n",
            "total loss: 0.2655775249004364 epoch 12 batch 125 \n",
            "total loss: 0.1553117036819458 epoch 13 batch 5 \n",
            "total loss: 0.12917208671569824 epoch 13 batch 10 \n",
            "total loss: 0.23377186059951782 epoch 13 batch 15 \n",
            "total loss: 0.17143402993679047 epoch 13 batch 20 \n",
            "total loss: 0.19789159297943115 epoch 13 batch 25 \n",
            "total loss: 0.17325706779956818 epoch 13 batch 30 \n",
            "total loss: 0.1461445689201355 epoch 13 batch 35 \n",
            "total loss: 0.1638738512992859 epoch 13 batch 40 \n",
            "total loss: 0.23124034702777863 epoch 13 batch 45 \n",
            "total loss: 0.19878023862838745 epoch 13 batch 50 \n",
            "total loss: 0.1812722235918045 epoch 13 batch 55 \n",
            "total loss: 0.24695098400115967 epoch 13 batch 60 \n",
            "total loss: 0.15736332535743713 epoch 13 batch 65 \n",
            "total loss: 0.18134035170078278 epoch 13 batch 70 \n",
            "total loss: 0.20316295325756073 epoch 13 batch 75 \n",
            "total loss: 0.17294305562973022 epoch 13 batch 80 \n",
            "total loss: 0.2048470824956894 epoch 13 batch 85 \n",
            "total loss: 0.1972559690475464 epoch 13 batch 90 \n",
            "total loss: 0.20555488765239716 epoch 13 batch 95 \n",
            "total loss: 0.15902088582515717 epoch 13 batch 100 \n",
            "total loss: 0.27476567029953003 epoch 13 batch 105 \n",
            "total loss: 0.24714398384094238 epoch 13 batch 110 \n",
            "total loss: 0.25630465149879456 epoch 13 batch 115 \n",
            "total loss: 0.269127756357193 epoch 13 batch 120 \n",
            "total loss: 0.23399965465068817 epoch 13 batch 125 \n",
            "total loss: 0.14865447580814362 epoch 14 batch 5 \n",
            "total loss: 0.16153651475906372 epoch 14 batch 10 \n",
            "total loss: 0.17261719703674316 epoch 14 batch 15 \n",
            "total loss: 0.22619158029556274 epoch 14 batch 20 \n",
            "total loss: 0.13681507110595703 epoch 14 batch 25 \n",
            "total loss: 0.16032403707504272 epoch 14 batch 30 \n",
            "total loss: 0.14292384684085846 epoch 14 batch 35 \n",
            "total loss: 0.13681446015834808 epoch 14 batch 40 \n",
            "total loss: 0.18409228324890137 epoch 14 batch 45 \n",
            "total loss: 0.1674126237630844 epoch 14 batch 50 \n",
            "total loss: 0.14732179045677185 epoch 14 batch 55 \n",
            "total loss: 0.13022463023662567 epoch 14 batch 60 \n",
            "total loss: 0.18770740926265717 epoch 14 batch 65 \n",
            "total loss: 0.16499507427215576 epoch 14 batch 70 \n",
            "total loss: 0.13566173613071442 epoch 14 batch 75 \n",
            "total loss: 0.15898260474205017 epoch 14 batch 80 \n",
            "total loss: 0.16641056537628174 epoch 14 batch 85 \n",
            "total loss: 0.1944132298231125 epoch 14 batch 90 \n",
            "total loss: 0.2262207269668579 epoch 14 batch 95 \n",
            "total loss: 0.20676560699939728 epoch 14 batch 100 \n",
            "total loss: 0.2102840393781662 epoch 14 batch 105 \n",
            "total loss: 0.19340692460536957 epoch 14 batch 110 \n",
            "total loss: 0.187296524643898 epoch 14 batch 115 \n",
            "total loss: 0.17335641384124756 epoch 14 batch 120 \n",
            "total loss: 0.2099289447069168 epoch 14 batch 125 \n",
            "total loss: 0.14340081810951233 epoch 15 batch 5 \n",
            "total loss: 0.14579172432422638 epoch 15 batch 10 \n",
            "total loss: 0.1293977051973343 epoch 15 batch 15 \n",
            "total loss: 0.15074902772903442 epoch 15 batch 20 \n",
            "total loss: 0.13329613208770752 epoch 15 batch 25 \n",
            "total loss: 0.1491243988275528 epoch 15 batch 30 \n",
            "total loss: 0.14245960116386414 epoch 15 batch 35 \n",
            "total loss: 0.14042304456233978 epoch 15 batch 40 \n",
            "total loss: 0.17087322473526 epoch 15 batch 45 \n",
            "total loss: 0.18867500126361847 epoch 15 batch 50 \n",
            "total loss: 0.17223608493804932 epoch 15 batch 55 \n",
            "total loss: 0.16629959642887115 epoch 15 batch 60 \n",
            "total loss: 0.15043802559375763 epoch 15 batch 65 \n",
            "total loss: 0.16201333701610565 epoch 15 batch 70 \n",
            "total loss: 0.1867101788520813 epoch 15 batch 75 \n",
            "total loss: 0.17749939858913422 epoch 15 batch 80 \n",
            "total loss: 0.18169927597045898 epoch 15 batch 85 \n",
            "total loss: 0.18131349980831146 epoch 15 batch 90 \n",
            "total loss: 0.18957491219043732 epoch 15 batch 95 \n",
            "total loss: 0.15851835906505585 epoch 15 batch 100 \n",
            "total loss: 0.15743960440158844 epoch 15 batch 105 \n",
            "total loss: 0.22563040256500244 epoch 15 batch 110 \n",
            "total loss: 0.17509043216705322 epoch 15 batch 115 \n",
            "total loss: 0.16400296986103058 epoch 15 batch 120 \n",
            "total loss: 0.20385797321796417 epoch 15 batch 125 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nDyK-EGqbN5r",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y98sfom7SuGy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 326
        },
        "outputId": "00d94338-e841-4bd6-f9e3-509ef1f1a08b"
      },
      "source": [
        "#In this section we evaluate our model on a raw_input converted to german, for this the entire sentence has to be passed\n",
        "#through the length of the model, for this we use greedsampler to run through the decoder\n",
        "#and the final embedding matrix trained on the data is used to generate embeddings\n",
        "input_raw='how are you'\n",
        "\n",
        "# We have a transcript file containing English-German pairs\n",
        "# Preprocess X\n",
        "input_raw = preprocess(input_raw, add_start_end=False)\n",
        "input_lines = [f'{SOS} {input_raw}']\n",
        "input_sequences = [[en_tokenizer.word_index[w] for w in line.split()] for line in input_lines]\n",
        "input_sequences = tf.keras.preprocessing.sequence.pad_sequences(input_sequences,\n",
        "                                                                maxlen=Tx, padding='post')\n",
        "inp = tf.convert_to_tensor(input_sequences)\n",
        "#print(inp.shape)\n",
        "inference_batch_size = input_sequences.shape[0]\n",
        "encoder_initial_cell_state = [tf.zeros((inference_batch_size, rnn_units)),\n",
        "                              tf.zeros((inference_batch_size, rnn_units))]\n",
        "encoder_emb_inp = encoderNetwork.encoder_embedding(inp)\n",
        "a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp,\n",
        "                                                initial_state =encoder_initial_cell_state)\n",
        "print('a_tx :', a_tx.shape)\n",
        "print('c_tx :', c_tx.shape)\n",
        "\n",
        "start_tokens = tf.fill([inference_batch_size],ge_tokenizer.word_index[SOS])\n",
        "\n",
        "end_token = ge_tokenizer.word_index[EOS]\n",
        "\n",
        "greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()\n",
        "\n",
        "decoder_input = tf.expand_dims([ge_tokenizer.word_index[SOS]]* inference_batch_size,1)\n",
        "decoder_emb_inp = decoderNetwork.decoder_embedding(decoder_input)\n",
        "\n",
        "decoder_instance = tfa.seq2seq.BasicDecoder(cell = decoderNetwork.rnn_cell, sampler = greedy_sampler,\n",
        "                                            output_layer=decoderNetwork.dense_layer)\n",
        "decoderNetwork.attention_mechanism.setup_memory(a)\n",
        "#pass [ last step activations , encoder memory_state ] as input to decoder for LSTM\n",
        "print(f\"decoder_initial_state = [a_tx, c_tx] : {np.array([a_tx, c_tx]).shape}\")\n",
        "decoder_initial_state = decoderNetwork.build_decoder_initial_state(inference_batch_size,\n",
        "                                                                   encoder_state=[a_tx, c_tx],\n",
        "                                                                   Dtype=tf.float32)\n",
        "print(f\"\"\"\n",
        "Compared to simple encoder-decoder without attention, the decoder_initial_state\n",
        "is an AttentionWrapperState object containing s_prev tensors and context and alignment vector\n",
        "\n",
        "decoder initial state shape: {np.array(decoder_initial_state).shape}\n",
        "decoder_initial_state tensor\n",
        "{decoder_initial_state}\n",
        "\"\"\")\n",
        "\n",
        "# Since we do not know the target sequence lengths in advance, we use maximum_iterations to limit the translation lengths.\n",
        "# One heuristic is to decode up to two times the source sentence lengths.\n",
        "maximum_iterations = tf.round(tf.reduce_max(Tx) * 2)\n",
        "\n",
        "#initialize inference decoder\n",
        "decoder_embedding_matrix = decoderNetwork.decoder_embedding.variables[0] \n",
        "(first_finished, first_inputs,first_state) = decoder_instance.initialize(decoder_embedding_matrix,\n",
        "                             start_tokens = start_tokens,\n",
        "                             end_token=end_token,\n",
        "                             initial_state = decoder_initial_state)\n",
        "#print( first_finished.shape)\n",
        "print(f\"first_inputs returns the same decoder_input i.e. embedding of  {SOS} : {first_inputs.shape}\")\n",
        "print(f\"start_index_emb_avg {tf.reduce_sum(tf.reduce_mean(first_inputs, axis=0))}\") # mean along the batch\n",
        "\n",
        "inputs = first_inputs\n",
        "state = first_state  \n",
        "predictions = np.empty((inference_batch_size,0), dtype = np.int32)                                                                             \n",
        "for j in range(maximum_iterations):\n",
        "    outputs, next_state, next_inputs, finished = decoder_instance.step(j,inputs,state)\n",
        "    inputs = next_inputs\n",
        "    state = next_state\n",
        "    outputs = np.expand_dims(outputs.sample_id,axis = -1)\n",
        "    predictions = np.append(predictions, outputs, axis = -1)"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a_tx : (1, 1024)\n",
            "c_tx : (1, 1024)\n",
            "decoder_initial_state = [a_tx, c_tx] : (2, 1, 1024)\n",
            "\n",
            "Compared to simple encoder-decoder without attention, the decoder_initial_state\n",
            "is an AttentionWrapperState object containing s_prev tensors and context and alignment vector\n",
            "\n",
            "decoder initial state shape: (6,)\n",
            "decoder_initial_state tensor\n",
            "AttentionWrapperState(cell_state=[<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=\n",
            "array([[ 0.0218722 , -0.00386145, -0.34212956, ..., -0.0818582 ,\n",
            "         0.0042587 , -0.06107492]], dtype=float32)>, <tf.Tensor: shape=(1, 1024), dtype=float32, numpy=\n",
            "array([[ 0.07267428, -0.01349923, -1.1421771 , ..., -0.27573448,\n",
            "         0.01418022, -0.14704482]], dtype=float32)>], attention=<tf.Tensor: shape=(1, 1024), dtype=float32, numpy=array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>, time=<tf.Tensor: shape=(), dtype=int32, numpy=0>, alignments=<tf.Tensor: shape=(1, 9), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>, alignment_history=(), attention_state=<tf.Tensor: shape=(1, 9), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>)\n",
            "\n",
            "first_inputs returns the same decoder_input i.e. embedding of  <start> : (1, 256)\n",
            "start_index_emb_avg -1.4379956722259521\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iodjSItQds1t",
        "colab_type": "text"
      },
      "source": [
        "## Final Translation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K6aWFB5IWlH2",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "outputId": "2179c9a3-cb27-447a-ac94-0e5ab2920aff"
      },
      "source": [
        "#prediction based on our sentence earlier\n",
        "print(\"English Sentence:\")\n",
        "print(input_raw)\n",
        "print(\"\\nGerman Translation:\")\n",
        "for i in range(len(predictions)):\n",
        "    line = predictions[i,:]\n",
        "    seq = list(itertools.takewhile( lambda index: index !=2, line))\n",
        "    print(\" \".join( [ge_tokenizer.index_word[w] for w in seq]))"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "English Sentence:\n",
            "how are you\n",
            "\n",
            "German Translation:\n",
            "wie du bist !\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g6Av-oPWvRc4",
        "colab_type": "text"
      },
      "source": [
        "### The accuracy can be improved by implementing:\n",
        "* Beam Search or Lexicon Search\n",
        "* Bi-directional encoder-decoder model "
      ]
    }
  ]
}